- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
          [mlir][SCF] Allow using a custom operation to generate loops with mlir::tileUsingSCF.
          #159660
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
  
    [mlir][SCF] Allow using a custom operation to generate loops with mlir::tileUsingSCF.
  
  #159660
              Conversation
| @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf Author: None (MaheshRavishankar) ChangesThis change adds an option to use a custom operation to generate the First one to generate the header of the loop. Presently this is adds support only for tiling. Subsequent commits The PR is split into two commits. The first commit is an NFC that just refactors the code (and cleans up some naming) to make it easier to add the support for custom loop operations. Note that this is duplicate of #159506 that was accidently committed and was reverted in #159598 to wait for reviews. Signed-off-by: MaheshRavishankar [[email protected]](mailto:[email protected]) Patch is 61.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159660.diff 5 Files Affected: 
 diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 3205da6e448fc..6b05ade37881c 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -33,6 +33,14 @@ using SCFTileSizeComputationFunction =
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
+  /// Specify which loop construct to use for tile and fuse.
+  enum class LoopType { ForOp, ForallOp, CustomOp };
+  LoopType loopType = LoopType::ForOp;
+  SCFTilingOptions &setLoopType(LoopType type) {
+    loopType = type;
+    return *this;
+  }
+
   /// Computation function that returns the tile sizes to use for each loop.
   /// Returning a tile size of zero implies no tiling for that loop. If the
   /// size of the returned vector is smaller than the number of loops, the inner
@@ -50,6 +58,17 @@ struct SCFTilingOptions {
   /// proper interaction with folding.
   SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
 
+  /// The interchange vector to reorder the tiled loops.
+  SmallVector<int64_t> interchangeVector = {};
+  SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
+    interchangeVector = llvm::to_vector(interchange);
+    return *this;
+  }
+
+  //-------------------------------------------------------------------------//
+  // Options related to tiling using `scf.forall`.
+  //-------------------------------------------------------------------------//
+
   /// Computation function that returns the number of threads to use for
   /// each loop. Returning a num threads of zero implies no tiling for that
   /// loop. If the size of the returned vector is smaller than the number of
@@ -70,21 +89,6 @@ struct SCFTilingOptions {
   /// function that computes num threads at the point they are needed.
   SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
 
-  /// The interchange vector to reorder the tiled loops.
-  SmallVector<int64_t> interchangeVector = {};
-  SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
-    interchangeVector = llvm::to_vector(interchange);
-    return *this;
-  }
-
-  /// Specify which loop construct to use for tile and fuse.
-  enum class LoopType { ForOp, ForallOp };
-  LoopType loopType = LoopType::ForOp;
-  SCFTilingOptions &setLoopType(LoopType type) {
-    loopType = type;
-    return *this;
-  }
-
   /// Specify mapping of loops to devices. This is only respected when the loop
   /// constructs support such a mapping (like `scf.forall`). Will be ignored
   /// when using loop constructs that dont support such a mapping (like
@@ -117,6 +121,98 @@ struct SCFTilingOptions {
     reductionDims.insert(dims.begin(), dims.end());
     return *this;
   }
+
+  //-------------------------------------------------------------------------//
+  // Options related to tiling using custom loop.
+  //-------------------------------------------------------------------------//
+
+  // For generating the inter-tile loops using a custom loop, two callback
+  // functions are needed
+  // 1. That generates the "loop header", i.e. the loop that iterates over the
+  //    different tiles.
+  // 2. That generates the loop terminator
+  //
+  // For `scf.forall` case the call back to generate loop header would generate
+  //
+  // ```mlir
+  // scf.forall (...) = ... {
+  //   ..
+  // }
+  // ```
+  //
+  // and the call back to generate the loop terminator would generate the
+  // `scf.in_parallel` region
+  //
+  // ```mlir
+  // scf.forall (...) = ... {
+  //   scf.in_parallel {
+  //      tensor.parallel_insert_slice ...
+  //   }
+  // }
+  // ```
+  //
+
+  // Information that is to be returned by the callback to generate the loop
+  // header needed for the rest of the tiled codegeneration.
+  // - `loops`: The generated loops
+  // - `tileOffset`: The values that represent the offset of the iteration space
+  // tile
+  // - `tileSizes` : The values that represent the size of the iteration space
+  // tile.
+  // - `destinationTensors` : The tensors to use as destinations during tiling.
+  struct CustomLoopHeaderInfo {
+    SmallVector<LoopLikeOpInterface> loops;
+    SmallVector<OpFoldResult> tileOffset;
+    SmallVector<OpFoldResult> tileSizes;
+    SmallVector<Value> destinationTensors;
+  };
+
+  // Type of the callback function that generates the loop headers.
+  // - `loopRanges` : Values that represent the full size of the iteration space
+  //                  being tiled.
+  // - `giveTileSizes` : The tile sizes that are to be used to tile the
+  // iteration
+  //                     space.
+  // - `destinationTensors` : The tensors to use as destinations for the results
+  //                          of the tiled loop for loops that implement
+  //                          `DestinationStyleOpInterface`.
+  // Returns the `CustomLoopHeaderInfo` object (described above). it is expected
+  // that this function sets the insertion point of `rewriter` to the program
+  // point where the intra-tile loop computation is to be generated.
+  using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
+      RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+      ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
+
+  // Type of the callback function that generates the loop terminator.
+  // - `tiledResults` : Tiles of the result computed for the iteration space
+  // tile
+  // - `resultOffsets` : For each of the `tiledResults`, the offset at which
+  //                     the result tile is to be "inserted" back into the
+  //                     destination tensor.
+  // - `resultSizes` : For each of the `tiledResults`, the size of the result
+  // tile
+  //                   that is to be "inserted" back into the destination
+  //                   tensor.
+  // Returns the `CustomLoopHeaderInfo` object (described above)
+  using GenerateLoopTerminatorFn = std::function<LogicalResult(
+      RewriterBase &rewriter, Location loc, ValueRange tiledResults,
+      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+      ValueRange destinationTensors)>;
+
+  // Callback function to generate the inter-tile loop header.
+  GenerateLoopHeaderFn generateLoopHeaderFn = nullptr;
+  // Callback function to generate the inter-tile loop terminator.
+  GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr;
+  // Helper function to set the callbacks for inter-tile loop header and
+  // terminator functions when using a custom operation for the loop.
+  SCFTilingOptions &
+  setCustomLoopGenerationFns(GenerateLoopHeaderFn headerFn,
+                             GenerateLoopTerminatorFn terminatorFn) {
+    generateLoopHeaderFn = std::move(headerFn);
+    generateLoopTerminatorFn = std::move(terminatorFn);
+    return *this;
+  }
 };
 
 /// Transformation information returned after tiling.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 834c02126fa53..c3899473289e2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -155,18 +155,18 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
 static LogicalResult checkTileSizes(TilingInterface op,
                                     scf::SCFTilingOptions::LoopType loopType,
                                     ReductionTilingStrategy reductionStrategy,
-                                    ArrayRef<OpFoldResult> tileSizes,
+                                    ArrayRef<OpFoldResult> givenTileSizes,
                                     ArrayRef<OpFoldResult> numThreads) {
   auto iterators = op.getLoopIteratorTypes();
-  assert(iterators.size() == tileSizes.size() &&
+  assert(iterators.size() == givenTileSizes.size() &&
          "expected as many tile size values as number of loops");
   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
          "when specified, expected number of threads to use for each loop");
 
   bool isParallelTiling = false;
-  for (auto [index, iterator, tileSize] :
-       llvm::enumerate(iterators, tileSizes)) {
-    if (!isConstantIntValue(tileSize, 0)) {
+  for (auto [index, iterator, givenTileSize] :
+       llvm::enumerate(iterators, givenTileSizes)) {
+    if (!isConstantIntValue(givenTileSize, 0)) {
       isParallelTiling |= iterator == utils::IteratorType::parallel;
     }
 
@@ -186,7 +186,7 @@ static LogicalResult checkTileSizes(TilingInterface op,
       }
 
       if (std::optional<int64_t> constTileSize =
-              getConstantIntValue(tileSize)) {
+              getConstantIntValue(givenTileSize)) {
         if (constTileSize.value() > 0 &&
             iterator != utils::IteratorType::parallel) {
           op.emitWarning() << "tiling is not thread safe at axis #" << index;
@@ -207,11 +207,11 @@ static LogicalResult checkTileSizes(TilingInterface op,
 /// Get the reduction dims that are tiled. This accounts for reduction dims
 /// that are specified as tiled, but the tile size is 0.
 static SetVector<unsigned>
-getSanitizedReductionDims(ArrayRef<OpFoldResult> tileSizes,
+getSanitizedReductionDims(ArrayRef<OpFoldResult> givenTileSizes,
                           const scf::SCFTilingOptions &options) {
   SetVector<unsigned> reductionDims;
   for (auto dim : options.reductionDims) {
-    if (isConstantIntValue(tileSizes[dim], 0))
+    if (isConstantIntValue(givenTileSizes[dim], 0))
       continue;
     reductionDims.insert(dim);
   }
@@ -236,14 +236,14 @@ static bool tileDividesIterationDomain(Range loopRange) {
 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                        Range loopRange, OpFoldResult offset,
-                                       OpFoldResult tileSize) {
-  std::optional<int64_t> ts = getConstantIntValue(tileSize);
+                                       OpFoldResult givenTileSize) {
+  std::optional<int64_t> ts = getConstantIntValue(givenTileSize);
   if (ts && ts.value() == 1)
-    return tileSize;
+    return givenTileSize;
 
   if (tileDividesIterationDomain(
-          Range{loopRange.offset, loopRange.size, tileSize}))
-    return tileSize;
+          Range{loopRange.offset, loopRange.size, givenTileSize}))
+    return givenTileSize;
 
   // The tile size to use (to avoid out of bounds access) is  minimum of
   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
@@ -254,15 +254,15 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
   AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
   return affine::makeComposedFoldedAffineMin(
-      b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
+      b, loc, minMap, SmallVector<OpFoldResult>{offset, size, givenTileSize});
 }
 
 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
 /// than `iterationSize`.
-static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize,
                                            OpFoldResult numThreads,
                                            OpFoldResult iterationSize) {
-  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+  std::optional<int64_t> tileSizeConst = getConstantIntValue(givenTileSize);
   std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
   std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
   if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
@@ -274,114 +274,51 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
 /// `offset`s and `size`s of the tile of the iteration space that the
 /// innermost loop body of the generated tiled loops corresponds to.
 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
-getTileOffsetAndSizes(RewriterBase &rewriter, Location loc,
-                      ReductionTilingStrategy strategy, ValueRange ivs,
+getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
                       ArrayRef<Range> iterationDomain,
-                      ArrayRef<OpFoldResult> tileSizes,
-                      ArrayRef<OpFoldResult> numThreads,
-                      const llvm::SetVector<unsigned> &reductionDims) {
+                      ArrayRef<OpFoldResult> givenTileSizes) {
   SmallVector<OpFoldResult> offsets, sizes;
   int materializedLoopNum = 0;
-
-  if (!numThreads.empty()) {
-    AffineExpr d0, d1, s0, s1;
-    AffineExpr offsetExpr, residualTileSizeExpr;
-    bindDims(rewriter.getContext(), d0, d1);
-    bindSymbols(rewriter.getContext(), s0, s1);
-    offsetExpr = d0 + d1 * s0;
-    residualTileSizeExpr = s1 - (d0 + d1 * s0);
-
-    for (auto [index, nt, tileSize, loopRange] :
-         llvm::enumerate(numThreads, tileSizes, iterationDomain)) {
-
-      // Non-tiled cases, set the offset and size to the
-      // `loopRange.offset/size`.
-      if (isZeroInteger(nt)) {
-        offsets.push_back(loopRange.offset);
-        sizes.push_back(loopRange.size);
-        continue;
-      }
-
-      Value iv = ivs[materializedLoopNum++];
-      OpFoldResult offset = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, offsetExpr,
-          ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
-      OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, residualTileSizeExpr,
-          {loopRange.offset, nt, tileSize, loopRange.size});
-
-      OpFoldResult size = tileSize;
-      if (!isZeroInteger(residualTileSize)) {
-        OpFoldResult sizeMinusOffsetPerThread =
-            affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
-                                                  {offset, loopRange.size});
-        size = affine::makeComposedFoldedAffineMin(
-            rewriter, loc,
-            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
-            {sizeMinusOffsetPerThread, tileSize});
-      }
-
-      // Consider the case where the original loop was `[0, 100)`.
-      // If number of threads are `7`, the tile size would be computed as
-      // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
-      // - `offset = 0 + 6 * 15 = 105`
-      // - `tileSize = min(15, 100 - 105) = -5`
-      // To avoid negative tile sizes, we need to do a further
-      // `nonNegativeTileSize = affine.max(0, tileSize)`.
-      // This `max` can be avoided if
-      //  `offset + tileSize * (numThreads - 1) < (ub - lb)`
-      if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
-        AffineMap maxMap =
-            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
-        size = affine::makeComposedFoldedAffineMax(
-            rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
-      }
-
-      offsets.push_back(offset);
-      sizes.push_back(size);
+  for (auto [givenTileSize, loopRange] :
+       llvm::zip_equal(givenTileSizes, iterationDomain)) {
+
+    // Non-tiled cases, set the offset and size to the
+    // `loopRange.offset/size`.
+    if (isZeroInteger(givenTileSize)) {
+      offsets.push_back(loopRange.offset);
+      sizes.push_back(loopRange.size);
+      continue;
     }
-    return {offsets, sizes};
-  } else {
-    for (auto [tileSize, loopRange] :
-         llvm::zip_equal(tileSizes, iterationDomain)) {
-
-      // Non-tiled cases, set the offset and size to the
-      // `loopRange.offset/size`.
-      if (isZeroInteger(tileSize)) {
-        offsets.push_back(loopRange.offset);
-        sizes.push_back(loopRange.size);
-        continue;
-      }
 
-      Value iv = ivs[materializedLoopNum++];
-      OpFoldResult offset = getAsOpFoldResult(iv);
-      offsets.push_back(offset);
-      OpFoldResult size =
-          getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
-      sizes.push_back(size);
-    }
-    return {offsets, sizes};
+    Value iv = ivs[materializedLoopNum++];
+    OpFoldResult offset = getAsOpFoldResult(iv);
+    offsets.push_back(offset);
+    OpFoldResult size =
+        getBoundedTileSize(rewriter, loc, loopRange, offset, givenTileSize);
+    sizes.push_back(size);
   }
+  return {offsets, sizes};
 }
 
 /// Function to return the bounds of the loops to be generated.
 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
                   SmallVector<OpFoldResult>>
 getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-              ArrayRef<OpFoldResult> tileSizes) {
+              ArrayRef<OpFoldResult> givenTileSizes) {
   SmallVector<OpFoldResult> lbs, ubs, steps;
-  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
+  for (auto [loopRange, givenTileSize] :
+       llvm::zip_equal(loopRanges, givenTileSizes)) {
     // No loop if the tile size is 0.
-    if (isZeroInteger(tileSize))
+    if (isZeroInteger(givenTileSize))
       continue;
     lbs.push_back(loopRange.offset);
     ubs.push_back(loopRange.size);
-    steps.push_back(tileSize);
+    steps.push_back(givenTileSize);
   }
   return {lbs, ubs, steps};
 }
 
-/// A function that allows returning additional yielded values during
+/// Typedef for function that allows returning additional yielded values during
 /// `yieldTiledValuesAndReplace`.
 /// - `ivs` induction variable for the loop.
 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
@@ -402,6 +339,30 @@ using YieldTiledValuesFn = std::function<LogicalResult(
     SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
     SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
 
+/// Typedef for function that implements the body of a tiled loop.
+/// - `ivs` induction variable for the loop.
+/// - `tileOffsets` represents offsets for the tiled iteration space.
+/// - `tileSizes` represents the sizes for the tiled iteraiton space.
+/// - `outerDestinationTensors` tensor that holds the result. Is same size
+///   as the destination operands of the original operations.
+/// - `tiledResults` results of the tiled computation, corresponds to
+///   tiles of the original operation computed by the loop body.
+///   Should be same size as the `destinationTensors`
+/// - `resultOffsets` is of the same size as `tiledResults` and represents
+///   the offset to use when writing the corresponding element from
+///   `tiledResults` into `destinationTensors`.
+/// - `resultOffsets` is of the same size as `tiledResults` and represents
+///   the size to use when writing the corresponding element from
+///   `tiledResults` into `destinationTensors`.
+/// In case the method needs to return `failure()` the method is expected
+/// to clean up any inserted operations.
+using GenerateTiledBodyFn = std::function<LogicalResult(
+    RewriterBase &rewriter, Location Loc, ValueRange ivs,
+    ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
+    ValueRange outerDestinationTensors, SmallVector<Value> &tiledResults,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
+
 /// Clones the operation and updates the destination if the operation
 /// implements the `DestinationStyleOpInterface`.
 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
@@ -417,26 +378,25 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
 
 /// Generate the tile-loop nest using `scf.for` operation.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
-/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
-/// - `destinationTensors` are the init values to use for the outer most loop.
-/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
+/// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
+/// - `outerDestinationTensors` are the init values to use for the outer most
+/// loop.
+/// - `tiledBodyFn` is called to generated the loop body of the inner
 /// most
 ///    loop.
-/// -...
[truncated]
 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nits/typos. New callbacks make sense, though see the comment about moving scf.for/forall to use the same mechanism and moving the code out of SCF in the future.
        
          
                mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
              
                Outdated
          
            Show resolved
            Hide resolved
        
              
          
                mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
              
                Outdated
          
            Show resolved
            Hide resolved
        
              
          
                mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
          
            Show resolved
            Hide resolved
        
      …ifferent loop types for tiling. Signed-off-by: MaheshRavishankar <[email protected]>
7f636c2    to
    e39fc28      
    Compare
  
    …ir::tileUsingSCF`. This change adds an option to use a custom operation to generate the inter-tile loops during tiling. When the loop type is set to `scf::SCFTilingOptions::LoopType::CustomOp`, the method `mlir::tileUsingSCF` provides two callback functions 1. First one to generate the header of the loop. 2. Second one to generate the terminator of the loop. These methods receive the information needed to generate the loops/terminator and expect to return information needed to generate the code for the intra-tile computation. See comments for more details. Presently this is adds support only for tiling. Subsequent commits will update this to add support for fusion as well. Signed-off-by: MaheshRavishankar <[email protected]>
e39fc28    to
    f41d2ec      
    Compare
  
    
This change adds an option to use a custom operation to generate the
inter-tile loops during tiling. When the loop type is set to
scf::SCFTilingOptions::LoopType::CustomOp, the method
mlir::tileUsingSCF provides two callback functions
First one to generate the header of the loop.
Second one to generate the terminator of the loop.
These methods receive the information needed to generate the
loops/terminator and expect to return information needed to generate
the code for the intra-tile computation. See comments for more
details.
Presently this is adds support only for tiling. Subsequent commits
will update this to add support for fusion as well.
The PR is split into two commits.
The first commit is an NFC that just refactors the code (and cleans up some naming) to make it easier to add the support for custom loop operations.
The second commit adds the support for using a custom loop operation, as well as a test to exercise this path.
Note that this is duplicate of #159506 that was accidently committed and was reverted in #159598 to wait for reviews.
Signed-off-by: MaheshRavishankar [email protected]